import elf
import sw5insts
import traceback
import struct
import sys
# Usage: <script> INPUT_ELF OUTPUT_ELF
#   e.g. INPUT_ELF = ../esmds/obj/sw5/nonbonded_sw5.slave.o
with open(sys.argv[1], "rb") as _elf_file:
    ef = elf.ELF64(_elf_file.read())
text1_shdr = ef.shdrs['.text1']
text1_offset = text1_shdr.sh_offset
text1_addr = text1_shdr.sh_addr
text1_size = text1_shdr.sh_size
# Section-header index of .text1, i.e. the st_shndx of symbols defined in it.
_text1_indices = [idx for idx, shdr in enumerate(ef.ishdrs) if shdr == text1_shdr]
if not _text1_indices:
    sys.exit("error: .text1 not found in the section header table")
text1_shndx = _text1_indices[-1]

# text1_barriers: sorted addresses of the first STB_LOCAL function in .text1
# found after each STT_FILE symbol. They approximate where each object file's
# piece of .text1 begins; relocations against the .text1 section symbol are
# rebased on them later.
text1_barriers = []
last_j = 0
for i, sym in enumerate(ef.syms):
    if elf.ELF64_ST_TYPE(sym.st_info) != elf.STT_FILE:
        continue
    # Symbols before last_j were already scanned for an earlier STT_FILE.
    last_j = max(i, last_j)
    for j in range(last_j, len(ef.syms)):
        last_j = j
        cand_info = ef.syms[j].st_info
        # Only local (binding 0) functions: in ELF an STT_FILE symbol precedes
        # the STB_LOCAL symbols of its file, whereas globals are not grouped
        # per file.
        is_local_func = ((cand_info >> 4) == 0
                         and elf.ELF64_ST_TYPE(cand_info) == elf.STT_FUNC)
        if is_local_func and ef.syms[j].st_shndx == text1_shndx:
            value = ef.syms[j].st_value
            if not text1_barriers or value != text1_barriers[-1]:
                text1_barriers.append(value)
            break
text1_barriers.sort()
import bisect

# sys.exit(0)
# Symbol lookup tables:
#   rsymtab maps st_value -> list of symbols with that value (symbol-table order)
#   symtab  maps st_name  -> symbol (a later duplicate name overrides an earlier one)
t12val = 0
gpval = 0
rsymtab = {}
for sym in ef.syms:
    if sym.st_value not in rsymtab:
        rsymtab[sym.st_value] = []
    rsymtab[sym.st_value].append(sym)
symtab = {entry.st_name: entry for entry in ef.syms}
relatext1_offset = ef.shdrs[".rela.text1"].sh_offset
relatext1_size = ef.shdrs[".rela.text1"].sh_size

from sw5relocation import SW5Rel
gotrels = {}
last_text1 = None
# Round -1: rebase relocations made against the .text1 *section* symbol by
# adding (barrier - text1_addr), where barrier is the nearest text1_barriers
# entry at or below r_offset. NOTE(review): this presumes such addends are
# relative to the originating object's piece of .text1 - confirm.
for i in range(relatext1_offset, relatext1_offset + relatext1_size, 24):
    addr, info, off = struct.unpack("<QQQ", ef.bin[i:i+24])
    target = info >> 32
    target_sym = ef.syms[target]
    if (elf.ELF64_ST_TYPE(target_sym.st_info) == elf.STT_SECTION
            and target_sym.st_shndx == text1_shndx):
        ibar = bisect.bisect_right(text1_barriers, addr) - 1
        # ibar == -1 means addr precedes every barrier (or there are none):
        # use the start of .text1 rather than wrapping to the last barrier.
        base = text1_barriers[ibar] if ibar >= 0 else text1_addr
        # The addend is stored as a 64-bit two's-complement value; keep the
        # sum within 64 bits so struct.pack("<Q") cannot overflow.
        off = (off + base - text1_addr) & 0xFFFFFFFFFFFFFFFF
        ef.bin[i:i+24] = struct.pack("<QQQ", addr, info, off)

# Assign a text1_got slot to every distinct (symbol, addend) pair referenced by
# a LITERAL/GOTTPREL relocation, in order of first appearance. The key format
# target << 32 | off is relied upon when the loads are rewritten.
for i in range(relatext1_offset, relatext1_offset + relatext1_size, 24):
    addr, info, off = struct.unpack("<QQQ", ef.bin[i:i+24])
    target = info >> 32
    reltype = SW5Rel(info & 63)
    if reltype in (SW5Rel.LITERAL, SW5Rel.GOTTPREL):
        target_id = target << 32 | off
        if target_id not in gotrels:
            gotrels[target_id] = (len(gotrels), ef.syms[target], off)

# Fill the text1_got table (located inside .data) with S + A for every GOT
# entry collected above; its address becomes the gp value used for .text1.
if 'text1_got' not in symtab:
    sys.exit("error: symbol text1_got not found")
text1_got = symtab['text1_got']
data_shdr = ef.shdrs['.data']
text1_got_offset = text1_got.st_value - data_shdr.sh_addr + data_shdr.sh_offset
for idx, sym, off in gotrels.values():
    slot = text1_got_offset + idx * 8
    # off is the raw (unsigned, two's-complement) addend: wrap S + A to 64 bits
    # so a negative addend does not overflow struct.pack("<Q").
    ef.bin[slot:slot + 8] = struct.pack("<Q", (sym.st_value + off) & 0xFFFFFFFFFFFFFFFF)

gpval = text1_got.st_value
def disp32_16(disp):
    """Split ``disp`` into a (high, low) pair of 16-bit displacements.

    ``low`` is bits 0..15 of ``disp`` sign-extended into [-32768, 32767];
    ``high`` is compensated for that sign extension so that
    ``high * 65536 + low == disp``.
    """
    low = ((disp & 0xffff) ^ 0x8000) - 0x8000  # sign-extend the low 16 bits
    high = (disp - low) >> 16
    return high, low
# Round 1: retarget every gp-based access in .text1 at text1_got (gpval).
#   GPDISP           - rewrite the high/low pair that sets gp ($29); when the
#                      high instruction's base register is not $27, replace
#                      both with 0x43ff15df (presumably a nop - confirm)
#   LITERAL/GOTTPREL - point the load at the entry's 8-byte slot in text1_got
#   GPRELHIGH/LOW    - recompute the high/low halves of (S + A - gp)
text1_end_offset = text1_offset + text1_size
for i in range(relatext1_offset, relatext1_offset + relatext1_size, 24):
    addr, info, off = struct.unpack("<QQQ", ef.bin[i:i+24])
    target = info >> 32
    # Same key (raw unsigned addend) as used when gotrels was built.
    target_id = target << 32 | off
    reltype = SW5Rel(info & 63)
    offset = addr - text1_addr + text1_offset
    if reltype == SW5Rel.GPDISP:
        disphi, displo = disp32_16(gpval - addr)
        insthi = sw5insts.parse_inst(struct.unpack("<I", ef.bin[offset:offset+4])[0])
        # Scan forward for the paired low instruction: op 0x3e, ra == rb == $29.
        loff = offset + 4
        instlo = sw5insts.parse_inst(struct.unpack("<I", ef.bin[loff:loff+4])[0])
        while instlo.op != 0x3e or instlo.ra != 29 or instlo.rb != 29:
            loff += 4
            if loff + 4 > text1_end_offset:
                raise ValueError("GPDISP at %s: no paired gp low instruction in .text1" % hex(addr))
            instlo = sw5insts.parse_inst(struct.unpack("<I", ef.bin[loff:loff+4])[0])
        insthi.disp = disphi
        instlo.disp = displo
        if insthi.rb == 27:
            ef.bin[offset:offset+4] = struct.pack("<I", insthi.inst)
            ef.bin[loff:loff+4] = struct.pack("<I", instlo.inst)
        else:
            ef.bin[offset:offset+4] = struct.pack("<I", 0x43ff15df)
            ef.bin[loff:loff+4] = struct.pack("<I", 0x43ff15df)
    elif reltype in (SW5Rel.LITERAL, SW5Rel.GOTTPREL):
        disp = gotrels[target_id][0] * 8
        # The slot offset must fit the signed 16-bit displacement field
        # (the same width disp32_16 produces for the low half).
        if disp > 0x7fff:
            raise ValueError("GOT slot offset %d at %s exceeds 16-bit displacement" % (disp, hex(addr)))
        inst = sw5insts.parse_inst(struct.unpack("<I", ef.bin[offset:offset+4])[0])
        inst.disp = disp
        ef.bin[offset:offset+4] = struct.pack("<I", inst.inst)
    elif reltype in (SW5Rel.GPRELHIGH, SW5Rel.GPRELLOW):
        # RELA value is S + A - gp; r_addend is signed but was unpacked as "<Q".
        addend = off - (1 << 64) if off >> 63 else off
        target_addr = ef.syms[target].st_value + addend
        disphi, displo = disp32_16(target_addr - gpval)
        inst = sw5insts.parse_inst(struct.unpack("<I", ef.bin[offset:offset+4])[0])
        inst.disp = disphi if reltype == SW5Rel.GPRELHIGH else displo
        ef.bin[offset:offset+4] = struct.pack("<I", inst.inst)
# Round 2: turn "load callee address via LITERAL, then use it" into a direct
# call. A LITERAL relocation immediately followed by a LITUSE relocation is
# rewritten: the load becomes 0x43ff15df and the using instruction becomes
# "bsr $26, <target>" (op 0x4, ra 26).
# NOTE(review): the LITUSE addend (use kind) is not checked, so a non-call use
# would be rewritten as well - confirm the compiler only pairs calls this way.
rela_end = relatext1_offset + relatext1_size
for i in range(relatext1_offset, rela_end, 24):
    addr, info, off = struct.unpack("<QQQ", ef.bin[i:i+24])
    target = info >> 32
    reltype = SW5Rel(info & 63)
    offset = addr - text1_addr + text1_offset
    # The last entry has no successor inside .rela.text1; never read past it.
    if reltype != SW5Rel.LITERAL or i + 48 > rela_end:
        continue
    addrnext, infonext, offnext = struct.unpack("<QQQ", ef.bin[i+24:i+48])
    if SW5Rel(infonext & 63) != SW5Rel.LITUSE:
        continue
    offsetnext = addrnext - text1_addr + text1_offset
    callee = ef.syms[target].st_value + off
    print("find a lituse", hex(callee), hex(addrnext))
    ef.bin[offset:offset+4] = struct.pack("<I", 0x43ff15df)
    # Assuming Alpha-style branches (target = pc + 4 + 4 * disp), the "+ 4"
    # lands on callee + 8, presumably skipping the callee's two-instruction gp
    # setup, which would read the no-longer-loaded $27.
    disp = (callee - addrnext + 4) >> 2
    brinst = sw5insts.BrInst(0)
    brinst.ra = 26
    brinst.disp = disp
    brinst.op = 0x4
    ef.bin[offsetnext:offsetnext+4] = struct.pack("<I", brinst.inst)

with open(sys.argv[2], "wb") as _out_file:
    _out_file.write(ef.bin)
# print(symtab["slave___errno_location"])
#print(gotrels)
#print(reltypes)